{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature processing with Spark, training with BlazingText and deploying as Inference Pipeline\n",
    "\n",
    "Typically a Machine Learning (ML) process consists of few steps: gathering data with various ETL jobs, pre-processing the data, featurizing the dataset by incorporating standard techniques or prior knowledge, and finally training an ML model using an algorithm.\n",
    "\n",
    "In many cases, when the trained model is used for processing real time or batch prediction requests, the model receives data in a format which needs to pre-processed (e.g. featurized) before it can be passed to the algorithm. In the following notebook, we will demonstrate how you can build your ML Pipeline leveraging Spark Feature Transformers and SageMaker BlazingText algorithm & after the model is trained, deploy the Pipeline (Feature Transformer and BlazingText) as an Inference Pipeline behind a single Endpoint for real-time inference and for batch inferences using Amazon SageMaker Batch Transform.\n",
    "\n",
    "In this notebook, we use Amazon Glue to run serverless Spark. Though the notebook demonstrates the end-to-end flow on a small dataset, the setup can be seamlessly used to scale to larger datasets."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Objective: Text Classification on DBPedia dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we will train the text classification model using SageMaker `BlazingText` algorithm on the [DBPedia Ontology Dataset](https://wiki.dbpedia.org/services-resources/dbpedia-data-set-2014#2) as done by [Zhang et al](https://arxiv.org/pdf/1509.01626.pdf). \n",
    "\n",
    "The DBpedia ontology dataset is constructed by picking 14 nonoverlapping classes from DBpedia 2014. It has 560,000 training samples and 70,000 testing samples. The fields we used for this dataset contain title and abstract of each Wikipedia article.\n",
    "\n",
    "\n",
    "Before passing the input data to `BlazingText`, we need to process this dataset into white-space separated tokens, have the label field in every line prefixed with `__label__` and all input data should be in a single file."
   ]
  },
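  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the target format, the sketch below turns one labelled record into a single `BlazingText` supervised-mode line. This is not the code in `dbpedia_processing.py`, which does the work with SparkML transformers. The helper name `to_blazingtext_line` and the lowercase-and-split tokenization are assumptions made for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper: illustrates the line format only; the Glue script uses SparkML instead\n",
    "def to_blazingtext_line(label, text):\n",
    "    # Naive tokenization (assumption): lowercase and split on whitespace\n",
    "    tokens = text.lower().split()\n",
    "    # One record per line: '__label__' + label, then the space-separated tokens\n",
    "    return '__label__{} {}'.format(label, ' '.join(tokens))\n",
    "\n",
    "example_line = to_blazingtext_line(1, 'Convair was an American aircraft manufacturing company')\n",
    "print(example_line)"
   ]
  },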
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Methodologies\n",
    "The Notebook consists of a few high-level steps:\n",
    "\n",
    "* Using AWS Glue for executing the SparkML feature processing job.\n",
    "* Using SageMaker BlazingText to train on the processed dataset produced by SparkML job.\n",
    "* Building an Inference Pipeline consisting of SparkML & BlazingText models for a realtime inference endpoint.\n",
    "* Building an Inference Pipeline consisting of SparkML & BlazingText models for a single Batch Transform job."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using AWS Glue for executing Spark jobs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll be running the SparkML job using [AWS Glue](https://aws.amazon.com/glue). AWS Glue is a serverless ETL service which can be used to execute standard Spark/PySpark jobs. Glue currently only supports `Python 2.7`, hence we'll write the script in `Python 2.7`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Permission setup for invoking AWS Glue from this Notebook\n",
    "In order to enable this Notebook to run AWS Glue jobs, we need to add one additional permission to the default execution role of this notebook. We will be using SageMaker Python SDK to retrieve the default execution role and then you have to go to [IAM Dashboard](https://console.aws.amazon.com/iam/home) to edit the Role to add AWS Glue specific permission."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finding out the current execution role of the Notebook\n",
    "We are using SageMaker Python SDK to retrieve the current role for this Notebook which needs to be enhanced."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import SageMaker Python SDK to get the Session and execution_role\n",
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "sess = sagemaker.Session()\n",
    "role = get_execution_role()\n",
    "print(role[role.rfind('/') + 1:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adding AWS Glue as an additional trusted entity to this role\n",
    "This step is needed if you want to pass the execution role of this Notebook while calling Glue APIs as well without creating an additional **Role**. If you have not used AWS Glue before, then this step is mandatory. \n",
    "\n",
    "If you have used AWS Glue previously, then you should have an already existing role that can be used to invoke Glue APIs. In that case, you can pass that role while calling Glue (later in this notebook) and skip this next step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the IAM dashboard, please click on **Roles** on the left sidenav and search for this Role. Once the Role appears, click on the Role to go to its **Summary** page. Click on the **Trust relationships** tab on the **Summary** page to add AWS Glue as an additional trusted entity. \n",
    "\n",
    "Click on **Edit trust relationship** and replace the JSON with this JSON.\n",
    "```\n",
    "{\n",
    "  \"Version\": \"2012-10-17\",\n",
    "  \"Statement\": [\n",
    "    {\n",
    "      \"Effect\": \"Allow\",\n",
    "      \"Principal\": {\n",
    "        \"Service\": [\n",
    "          \"sagemaker.amazonaws.com\",\n",
    "          \"glue.amazonaws.com\"\n",
    "        ]\n",
    "      },\n",
    "      \"Action\": \"sts:AssumeRole\"\n",
    "    }\n",
    "  ]\n",
    "}\n",
    "```\n",
    "Once this is complete, click on **Update Trust Policy** and you are done."
   ]
  },
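  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you prefer to script this step rather than use the console, the same trust policy can be built in Python and applied with the IAM `update_assume_role_policy` API. This is a sketch under assumptions: the helper name `build_trust_policy` is hypothetical, and the final call is left commented out because it replaces the existing trust policy and requires IAM permissions that the Notebook role may not have."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Hypothetical helper: builds the same trust policy document shown above\n",
    "def build_trust_policy(services):\n",
    "    return {\n",
    "        'Version': '2012-10-17',\n",
    "        'Statement': [\n",
    "            {\n",
    "                'Effect': 'Allow',\n",
    "                'Principal': {'Service': list(services)},\n",
    "                'Action': 'sts:AssumeRole'\n",
    "            }\n",
    "        ]\n",
    "    }\n",
    "\n",
    "trust_policy = build_trust_policy(['sagemaker.amazonaws.com', 'glue.amazonaws.com'])\n",
    "print(json.dumps(trust_policy, indent=2))\n",
    "\n",
    "# Uncomment to apply (assumes the caller is allowed to call iam:UpdateAssumeRolePolicy):\n",
    "# import boto3\n",
    "# iam = boto3.client('iam')\n",
    "# iam.update_assume_role_policy(RoleName=role[role.rfind('/') + 1:],\n",
    "#                               PolicyDocument=json.dumps(trust_policy))"
   ]
  },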
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Downloading dataset and uploading to S3\n",
    "SageMaker team has downloaded the dataset and uploaded to one of the S3 buckets in our account. In this notebook, we will download from that bucket and upload to your bucket so that AWS Glue can access the data. The default AWS Glue permissions we just added expects the data to be present in a bucket with the string `aws-glue`. Hence, after we download the dataset, we will create an S3 bucket in your account with a valid name and then upload the data to S3. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://s3-us-west-2.amazonaws.com/sparkml-mleap/data/dbpedia/train.csv\n",
    "!wget https://s3-us-west-2.amazonaws.com/sparkml-mleap/data/dbpedia/test.csv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating an S3 bucket and uploading this dataset\n",
    "Next we will create an S3 bucket with the `aws-glue` string in the name and upload this data to the S3 bucket. In case you want to use some existing bucket to run your Spark job via AWS Glue, you can use that bucket to upload your data provided the `Role` has access permission to upload and download from that bucket.\n",
    "\n",
    "Once the bucket is created, the following cell would also update the `train.csv` and `test.csv` files downloaded locally to this bucket under the `input/dbpedia` prefix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import botocore\n",
    "from botocore.exceptions import ClientError\n",
    "\n",
    "boto_session = sess.boto_session\n",
    "s3 = boto_session.resource('s3')\n",
    "account = boto_session.client('sts').get_caller_identity()['Account']\n",
    "region = boto_session.region_name\n",
    "default_bucket = 'aws-glue-{}-{}'.format(account, region)\n",
    "\n",
    "try:\n",
    "    if region == 'us-east-1':\n",
    "        s3.create_bucket(Bucket=default_bucket)\n",
    "    else:\n",
    "        s3.create_bucket(Bucket=default_bucket, CreateBucketConfiguration={'LocationConstraint': region})\n",
    "except ClientError as e:\n",
    "    error_code = e.response['Error']['Code']\n",
    "    message = e.response['Error']['Message']\n",
    "    if error_code == 'BucketAlreadyOwnedByYou':\n",
    "        print ('A bucket with the same name already exists in your account - using the same bucket.')\n",
    "        pass        \n",
    "\n",
    "# Uploading the training data to S3\n",
    "sess.upload_data(path='train.csv', bucket=default_bucket, key_prefix='input/dbpedia')    \n",
    "sess.upload_data(path='test.csv', bucket=default_bucket, key_prefix='input/dbpedia')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing the feature processing script using SparkML\n",
    "\n",
    "The code for feature transformation using SparkML can be found in `dbpedia_processing.py` file written in the same directory. You can go through the code itself to see how it is using standard SparkML feature transformers to define the Pipeline for featurizing and processing the data.\n",
    "\n",
    "Once the Spark ML Pipeline `fit` and `transform` is done, we are tranforming the `train` and `test` file and writing it in the format `BlazingText` expects before uploading to S3."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Serializing the trained Spark ML Model with [MLeap](https://github.com/combust/mleap)\n",
    "Apache Spark is best suited batch processing workloads. In order to use the Spark ML model we trained for low latency inference, we need to use the MLeap library to serialize it to an MLeap bundle and later use the [SageMaker SparkML Serving](https://github.com/aws/sagemaker-sparkml-serving-container) to perform realtime and batch inference. \n",
    "\n",
    "By using the `SerializeToBundle()` method from MLeap in the script, we are serializing the ML Pipeline into an MLeap bundle and uploading to S3 in `tar.gz` format as SageMaker expects."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Uploading the code and other dependencies to S3 for AWS Glue\n",
    "Unlike SageMaker, in order to run your code in AWS Glue, we do not need to prepare a Docker image. We can upload your code and dependencies directly to S3 and pass those locations while invoking the Glue job."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Upload the featurizer script to S3\n",
    "We will be uploading the `dbpedia_processing.py` script to S3 now so that Glue can use it to run the PySpark job. You can replace it with your own script if needed. If your code has multiple files, you need to zip those files and upload to S3 instead of uploading a single file like it's being done here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "script_location = sess.upload_data(path='dbpedia_processing.py', bucket=default_bucket, key_prefix='codes')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Upload MLeap dependencies to S3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For our job, we will also have to pass MLeap dependencies to Glue.MLeap is an additional library we are using which does not come bundled with default Spark.\n",
    "Similar to most of the packages in the Spark ecosystem, MLeap is also implemented as a Scala package with a front-end wrapper written in Python so that it can be used from PySpark. We need to make sure that the MLeap Python library as well as the JAR is available within the Glue job environment. In the following cell, we will download the MLeap Python dependency & JAR from a SageMaker hosted bucket and upload to the S3 bucket we created above in your account. \n",
    "If you are using some other Python libraries like `nltk` in your code, you need to download the wheel file from PyPI and upload to S3 in the same way. At this point, Glue only supports passing pure Python libraries in this way (e.g. you can not pass `Pandas` or `OpenCV`). However you can use `NumPy` & `SciPy` without having to pass these as packages because these are pre-installed in the Glue environment. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://s3-us-west-2.amazonaws.com/sparkml-mleap/0.9.6/python/python.zip\n",
    "!wget https://s3-us-west-2.amazonaws.com/sparkml-mleap/0.9.6/jar/mleap_spark_assembly.jar    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "python_dep_location = sess.upload_data(path='python.zip', bucket=default_bucket, key_prefix='dependencies/python')\n",
    "jar_dep_location = sess.upload_data(path='mleap_spark_assembly.jar', bucket=default_bucket, key_prefix='dependencies/jar')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining output locations for the data and model\n",
    "Next we define the output location where the transformed dataset should be uploaded. We are also specifying a model location where the MLeap serialized model would be updated. This locations should be consumed as part of the Spark script using `getResolvedOptions` method of AWS Glue library (see `dbpedia_processing.py` for details).\n",
    "By designing our code in this way, we can re-use these variables as part of the SageMaker training job (details below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from time import gmtime, strftime\n",
    "import time\n",
    "\n",
    "timestamp_prefix = strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",
    "\n",
    "# Input location of the data, We uploaded our train.csv file to input key previously\n",
    "s3_input_bucket = default_bucket\n",
    "s3_input_key_prefix = 'input/dbpedia'\n",
    "\n",
    "# Output location of the data. The input data will be split, transformed, and \n",
    "# uploaded to output/train and output/validation\n",
    "s3_output_bucket = default_bucket\n",
    "s3_output_key_prefix = timestamp_prefix + '/dbpedia'\n",
    "\n",
    "# the MLeap serialized SparkML model will be uploaded to output/mleap\n",
    "s3_model_bucket = default_bucket\n",
    "s3_model_key_prefix = s3_output_key_prefix + '/mleap'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calling Glue APIs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we'll be creating Glue client via Boto so that we can invoke the `create_job` API of Glue. `create_job` API will create a job definition which can be used to execute your jobs in Glue. The job definition created here is mutable. While creating the job, we are also passing the code location as well as the dependencies location to Glue.\n",
    "\n",
    "`AllocatedCapacity` parameter controls the hardware resources that Glue will use to execute this job. It is measures in units of `DPU`. For more information on `DPU`, please see [here](https://docs.aws.amazon.com/glue/latest/dg/add-job.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "glue_client = boto_session.client('glue')\n",
    "job_name = 'sparkml-dbpedia-' + timestamp_prefix\n",
    "response = glue_client.create_job(\n",
    "    Name=job_name,\n",
    "    Description='PySpark job to featurize the DBPedia dataset',\n",
    "    Role=role, # you can pass your existing AWS Glue role here if you have used Glue before\n",
    "    ExecutionProperty={\n",
    "        'MaxConcurrentRuns': 1\n",
    "    },\n",
    "    Command={\n",
    "        'Name': 'glueetl',\n",
    "        'ScriptLocation': script_location\n",
    "    },\n",
    "    DefaultArguments={\n",
    "        '--job-language': 'python',\n",
    "        '--extra-jars' : jar_dep_location,\n",
    "        '--extra-py-files': python_dep_location\n",
    "    },\n",
    "    AllocatedCapacity=10,\n",
    "    Timeout=60,\n",
    ")\n",
    "glue_job_name = response['Name']\n",
    "print(glue_job_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The aforementioned job will be executed now by calling `start_job_run` API. This API creates an immutable run/execution corresponding to the job definition created above. We will require the `job_run_id` for the particular job execution to check for status. We'll pass the data and model locations as part of the job execution parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job_run_id = glue_client.start_job_run(JobName=job_name,\n",
    "                                       Arguments = {\n",
    "                                        '--S3_INPUT_BUCKET': s3_input_bucket,\n",
    "                                        '--S3_INPUT_KEY_PREFIX': s3_input_key_prefix,\n",
    "                                        '--S3_OUTPUT_BUCKET': s3_output_bucket,\n",
    "                                        '--S3_OUTPUT_KEY_PREFIX': s3_output_key_prefix,\n",
    "                                        '--S3_MODEL_BUCKET': s3_model_bucket,\n",
    "                                        '--S3_MODEL_KEY_PREFIX': s3_model_key_prefix\n",
    "                                       })['JobRunId']\n",
    "print(job_run_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Checking Glue job status"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will check for the job status to see if it has `succeeded`, `failed` or `stopped`. Once the job is succeeded, we have the transformed data into S3 in CSV format which we can use with `BlazingText` for training. If the job fails, you can go to [AWS Glue console](https://us-west-2.console.aws.amazon.com/glue/home), click on **Jobs** tab on the left, and from the page, click on this particular job and you will be able to find the CloudWatch logs (the link under **Logs**) link for these jobs which can help you to see what exactly went wrong in the `spark-submit` call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job_run_status = glue_client.get_job_run(JobName=job_name,RunId=job_run_id)['JobRun']['JobRunState']\n",
    "while job_run_status not in ('FAILED', 'SUCCEEDED', 'STOPPED'):\n",
    "    job_run_status = glue_client.get_job_run(JobName=job_name,RunId=job_run_id)['JobRun']['JobRunState']\n",
    "    print (job_run_status)\n",
    "    time.sleep(30)"
   ]
  },
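  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above has no upper bound on how long it waits. As a sketch of a more defensive variant, the hypothetical helper `wait_for_glue_job` below adds an overall deadline and returns on any of the terminal states `SUCCEEDED`, `FAILED`, `STOPPED` or `TIMEOUT`. The function name, its parameters and the default values are assumptions for illustration; only `get_job_run` is a Glue API call. The cell only defines the helper and does not call it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Hypothetical helper (not part of boto3): polls a Glue job run until it finishes or a deadline passes\n",
    "def wait_for_glue_job(client, job_name, run_id, poll_seconds=30, max_wait_seconds=3600):\n",
    "    terminal_states = ('SUCCEEDED', 'FAILED', 'STOPPED', 'TIMEOUT')\n",
    "    deadline = time.time() + max_wait_seconds\n",
    "    while True:\n",
    "        state = client.get_job_run(JobName=job_name, RunId=run_id)['JobRun']['JobRunState']\n",
    "        if state in terminal_states:\n",
    "            return state\n",
    "        if time.time() >= deadline:\n",
    "            raise RuntimeError('Glue job run {} still in state {} after {} seconds'.format(\n",
    "                run_id, state, max_wait_seconds))\n",
    "        time.sleep(poll_seconds)\n",
    "\n",
    "# Usage: final_state = wait_for_glue_job(glue_client, job_name, job_run_id)"
   ]
  },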
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using SageMaker BlazingText to train on the processed dataset produced by SparkML job"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will use SageMaker `BlazingText` algorithm to train a text classification model this dataset. We already know the S3 location where the preprocessed training data was uploaded as part of the Glue job."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### We need to retrieve the BlazingText algorithm image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "\n",
    "training_image = get_image_uri(sess.boto_region_name, 'blazingtext', repo_version=\"latest\")\n",
    "print (training_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Next BlazingText model parameters and dataset details will be set properly\n",
    "We have parameterized the notebook so that the same data location which was used in the PySpark script can now be passed to `BlazingText` Estimator as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_train_data = 's3://{}/{}/{}'.format(s3_output_bucket, s3_output_key_prefix, 'train')\n",
    "s3_validation_data = 's3://{}/{}/{}'.format(s3_output_bucket, s3_output_key_prefix, 'validation')\n",
    "s3_output_location = 's3://{}/{}/{}'.format(s3_output_bucket, s3_output_key_prefix, 'bt_model')\n",
    "\n",
    "bt_model = sagemaker.estimator.Estimator(training_image,\n",
    "                                         role, \n",
    "                                         train_instance_count=1, \n",
    "                                         train_instance_type='ml.c4.xlarge',\n",
    "                                         train_volume_size = 20,\n",
    "                                         train_max_run = 3600,\n",
    "                                         input_mode= 'File',\n",
    "                                         output_path=s3_output_location,\n",
    "                                         sagemaker_session=sess)\n",
    "\n",
    "bt_model.set_hyperparameters(mode=\"supervised\",\n",
    "                            epochs=10,\n",
    "                            min_count=2,\n",
    "                            learning_rate=0.05,\n",
    "                            vector_dim=10,\n",
    "                            early_stopping=True,\n",
    "                            patience=4,\n",
    "                            min_epochs=5,\n",
    "                            word_ngrams=2)\n",
    "\n",
    "train_data = sagemaker.session.s3_input(s3_train_data, distribution='FullyReplicated', \n",
    "                        content_type='text/plain', s3_data_type='S3Prefix')\n",
    "validation_data = sagemaker.session.s3_input(s3_validation_data, distribution='FullyReplicated', \n",
    "                             content_type='text/plain', s3_data_type='S3Prefix')\n",
    "\n",
    "data_channels = {'train': train_data, 'validation': validation_data}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finally BlazingText training will be performed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bt_model.fit(inputs=data_channels, logs=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building an Inference Pipeline consisting of SparkML & BlazingText models for a realtime inference endpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we will proceed with deploying the models in SageMaker to create an Inference Pipeline. You can create an Inference Pipeline with upto five containers.\n",
    "\n",
    "Deploying a model in SageMaker requires two components:\n",
    "\n",
    "* Docker image residing in ECR.\n",
    "* Model artifacts residing in S3.\n",
    "\n",
    "**SparkML**\n",
    "\n",
    "For SparkML, Docker image for MLeap based SparkML serving is provided by SageMaker team. For more information on this, please see [SageMaker SparkML Serving](https://github.com/aws/sagemaker-sparkml-serving-container). MLeap serialized SparkML model was uploaded to S3 as part of the SparkML job we executed in AWS Glue.\n",
    "\n",
    "**BlazingText**\n",
    "\n",
    "For BlazingText, we will use the same Docker image we used for training. The model artifacts for BlazingText was uploaded as part of the training job we just ran."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the Endpoint with both containers\n",
    "Next we'll create a SageMaker inference endpoint with both the `sagemaker-sparkml-serving` & `BlazingText` containers. For this, we will first create a `PipelineModel` which will consist of both the `SparkML` model as well as `BlazingText` model in the right sequence."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Passing the schema of the payload via environment variable\n",
    "SparkML serving container needs to know the schema of the request that'll be passed to it while calling the `predict` method. In order to alleviate the pain of not having to pass the schema with every request, `sagemaker-sparkml-serving` allows you to pass it via an environment variable while creating the model definitions. This schema definition will be required in our next step for creating a model.\n",
    "\n",
    "We will see later that you can overwrite this schema on a per request basis by passing it as part of the individual request payload as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "schema = {\n",
    "    \"input\": [\n",
    "        {\n",
    "            \"name\": \"abstract\",\n",
    "            \"type\": \"string\"\n",
    "        }\n",
    "    ],\n",
    "    \"output\": \n",
    "        {\n",
    "            \"name\": \"tokenized_abstract\",\n",
    "            \"type\": \"string\",\n",
    "            \"struct\": \"array\"\n",
    "        }\n",
    "}\n",
    "schema_json = json.dumps(schema)\n",
    "print(schema_json)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a `PipelineModel` which comprises of the SparkML and BlazingText model in the right order\n",
    "\n",
    "Next we'll create a SageMaker `PipelineModel` with SparkML and BlazingText.The `PipelineModel` will ensure that both the containers get deployed behind a single API endpoint in the correct order. The same model would later be used for Batch Transform as well to ensure that a single job is sufficient to do prediction against the Pipeline. \n",
    "\n",
    "Here, during the `Model` creation for SparkML, we will pass the schema definition that we built in the previous cell."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Controlling the output format from `sagemaker-sparkml-serving` to the next container\n",
    "\n",
    "By default, `sagemaker-sparkml-serving` returns an output in `CSV` format. However, BlazingText does not understand CSV format and it supports a different format. \n",
    "\n",
    "In order for the `sagemaker-sparkml-serving` to emit the output with the right format, we need to pass a second environment variable `SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT` with the value `application/jsonlines;data=text` to ensure that `sagemaker-sparkml-serving` container emits response in the proper format which BlazingText can parse.\n",
    "\n",
    "For more information on different output formats `sagemaker-sparkml-serving` supports, please check the documentation pointed above. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.model import Model\n",
    "from sagemaker.pipeline import PipelineModel\n",
    "from sagemaker.sparkml.model import SparkMLModel\n",
    "\n",
    "sparkml_data = 's3://{}/{}/{}'.format(s3_model_bucket, s3_model_key_prefix, 'model.tar.gz')\n",
    "# passing the schema defined above by using an environment variable that sagemaker-sparkml-serving understands\n",
    "sparkml_model = SparkMLModel(model_data=sparkml_data,\n",
    "                             env={'SAGEMAKER_SPARKML_SCHEMA' : schema_json, \n",
    "                                  'SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT': \"application/jsonlines;data=text\"})\n",
    "bt_model = Model(model_data=bt_model.model_data, image=training_image)\n",
    "\n",
    "model_name = 'inference-pipeline-' + timestamp_prefix\n",
    "sm_model = PipelineModel(name=model_name, role=role, models=[sparkml_model, bt_model])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deploying the `PipelineModel` to an endpoint for realtime inference\n",
    "Next we will deploy the model we just created with the `deploy()` method to start an inference endpoint and we will send some requests to the endpoint to verify that it works as expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoint_name = 'inference-pipeline-ep-' + timestamp_prefix\n",
    "sm_model.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge', endpoint_name=endpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Invoking the newly created inference endpoint with a payload to transform the data\n",
    "Now we will invoke the endpoint with a valid payload that `sagemaker-sparkml-serving` can recognize. There are three ways in which input payload can be passed to the request:\n",
    "\n",
    "* Pass it as a valid CSV string. In this case, the schema passed via the environment variable will be used to determine the schema. For CSV format, every column in the input has to be a basic datatype (e.g. int, double, string) and it can not be a Spark `Array` or `Vector`.\n",
    "\n",
    "* Pass it as a valid JSON string. In this case as well, the schema passed via the environment variable will be used to infer the schema. With JSON format, every column in the input can be a basic datatype or a Spark `Vector` or `Array` provided that the corresponding entry in the schema mentions the correct value.\n",
    "\n",
    "* Pass the request in JSON format along with the schema and the data. In this case, the schema passed in the payload will take precedence over the one passed via the environment variable (if any)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Passing the payload in CSV format\n",
    "We will first see how the payload can be passed to the endpoint in CSV format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.predictor import json_serializer, csv_serializer, json_deserializer, RealTimePredictor\n",
    "from sagemaker.content_types import CONTENT_TYPE_CSV, CONTENT_TYPE_JSON\n",
    "payload = \"Convair was an american aircraft manufacturing company which later expanded into rockets and spacecraft.\"\n",
    "predictor = RealTimePredictor(endpoint=endpoint_name, sagemaker_session=sess, serializer=csv_serializer,\n",
    "                                content_type=CONTENT_TYPE_CSV, accept='application/jsonlines')\n",
    "print(predictor.predict(payload))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Passing the payload in JSON format\n",
    "We will now pass a different payload in JSON format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "payload = {\"data\": [\"Berwick secondary college is situated in the outer melbourne metropolitan suburb of berwick .\"]}\n",
    "predictor = RealTimePredictor(endpoint=endpoint_name, sagemaker_session=sess, serializer=json_serializer,\n",
    "                                content_type=CONTENT_TYPE_JSON)\n",
    "\n",
    "print(predictor.predict(payload))"
   ]
  },
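  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Passing the schema together with the data\n",
    "The third option listed earlier, sending the schema as part of the request, is not exercised by the cells above. The sketch below only builds such a payload. The envelope with top-level `schema` and `data` keys follows our reading of the `sagemaker-sparkml-serving` documentation and should be treated as an assumption to verify there. The variable names are illustrative, and the actual `predict` call is left commented out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Assumed envelope: a per-request 'schema' next to 'data' takes precedence over the environment variable\n",
    "inline_schema = {\n",
    "    'input': [{'name': 'abstract', 'type': 'string'}],\n",
    "    'output': {'name': 'tokenized_abstract', 'type': 'string', 'struct': 'array'}\n",
    "}\n",
    "payload_with_schema = {\n",
    "    'schema': inline_schema,\n",
    "    'data': ['Convair was an american aircraft manufacturing company .']\n",
    "}\n",
    "print(json.dumps(payload_with_schema, indent=2))\n",
    "\n",
    "# Uncomment while the endpoint is running; reuses the JSON predictor created above:\n",
    "# print(predictor.predict(payload_with_schema))"
   ]
  },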
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Optional] Deleting the Endpoint\n",
    "If you do not plan to use this endpoint, then it is a good practice to delete the endpoint so that you do not incur the cost of running it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sm_client = boto_session.client('sagemaker')\n",
    "sm_client.delete_endpoint(EndpointName=endpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building an Inference Pipeline consisting of SparkML & BlazingText models for a single Batch Transform job\n",
    "SageMaker Batch Transform also supports chaining multiple containers together when deploying an Inference Pipeline and performing a single Batch Transform job to transform your data for a batch use-case similar to the real-time use-case we have seen above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing data for Batch Transform\n",
    "Batch Transform requires data in the same format described above, with one CSV or JSON being per line. For this notebook, SageMaker team has created a sample input in CSV format which Batch Transform can process. The input is a simple CSV file with one input string per line.\n",
    "\n",
    "Next we will download a sample of this data from one of the SageMaker buckets (named `batch_input_dbpedia.csv`) and upload to your S3 bucket. We will also inspect first five rows of the data post downloading."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://s3-us-west-2.amazonaws.com/sparkml-mleap/data/batch_input_dbpedia.csv\n",
    "!printf \"\\n\\nShowing first two lines\\n\\n\"    \n",
    "!head -n 3 batch_input_dbpedia.csv\n",
    "!printf \"\\n\\nAs we can see, it is just one input string per line.\\n\\n\" "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_input_loc = sess.upload_data(path='batch_input_dbpedia.csv', bucket=default_bucket, key_prefix='batch')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Invoking the Transform API to create a Batch Transform job\n",
    "Next we will create a Batch Transform job using the `Transformer` class from Python SDK to create a Batch Transform job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_data_path = 's3://{}/{}/{}'.format(default_bucket, 'batch', 'batch_input_dbpedia.csv')\n",
    "output_data_path = 's3://{}/{}/{}'.format(default_bucket, 'batch_output/dbpedia', timestamp_prefix)\n",
    "transformer = sagemaker.transformer.Transformer(\n",
    "    model_name = model_name,\n",
    "    instance_count = 1,\n",
    "    instance_type = 'ml.m4.xlarge',\n",
    "    strategy = 'SingleRecord',\n",
    "    assemble_with = 'Line',\n",
    "    output_path = output_data_path,\n",
    "    base_transform_job_name='serial-inference-batch',\n",
    "    sagemaker_session=sess,\n",
    "    accept = CONTENT_TYPE_CSV\n",
    ")\n",
    "transformer.transform(data = input_data_path, \n",
    "                        content_type = CONTENT_TYPE_CSV, \n",
    "                        split_type = 'Line')\n",
    "transformer.wait()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python2",
   "language": "python",
   "name": "conda_python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
